import warnings
import numpy as np
import scipy as sp
from tqdm import tqdm
from scipy.integrate import odeint

# Silence every warning, process-wide. This runs at import time and affects all
# code in the process, not just this module.
# NOTE(review): this also hides any warnings raised by odeint and by deprecated
# APIs used below; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Gravitational constant. NOTE(review): value presumes SI units (m^3 kg^-1 s^-2).
G = 6.67408e-11

# Reference scales used to non-dimensionalise the equations of motion.
# NOTE(review): units presumed SI (kg, m, m/s, s) -- confirm.
m_nd = 1.989e+30  # reference mass; presumably one solar mass in kg -- confirm
r_nd = 5.326e+12  # reference length
v_nd = 30000  # reference velocity
t_nd = 79.91 * 365 * 24 * 3600 * 0.51  # reference time: 79.91 (365-day) years in seconds, scaled by 0.51

# Dimensionless coefficients built from the reference scales: threeBodyEquations
# multiplies the gravitational accelerations by K1 and the velocities by K2.
K1 = G * t_nd * m_nd / (np.square(r_nd) * v_nd)
K2 = v_nd * t_nd / r_nd

def add_noise(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` scaled element-wise by independent random factors.

    Every element is multiplied by its own factor drawn uniformly from
    ``[0.995, 1.005)``, i.e. up to +/-0.5 % relative jitter. Random numbers
    come from the global ``np.random`` state, and a new array is returned.
    """
    # Uniform in [-0.5, 0.5), then shrunk to [-0.005, 0.005).
    jitter = np.random.random(arr.shape) - 0.5
    scale = 1 + jitter / 100
    # Keep the factor on the left-hand side, exactly as the product is formed
    # elsewhere, so ndarray subclasses dispatch the same way.
    return scale * arr

def threeBodyEquations(w, t, G, m1, m2, m3, k1=None, k2=None):
    """Right-hand side of the non-dimensional three-body equations, for ``odeint``.

    Args:
        w: Flat state vector laid out as ``[r1, r2, r3, v1, v2, v3]``. It is
            reshaped into six equal rows, so each position/velocity has
            ``len(w) // 6`` components (3 in the simulations in this file).
        t: Current time. Unused (the system is autonomous), but ``odeint``
            passes it.
        G: Unused. The gravitational constant is already folded into ``K1``.
            Kept so existing ``args=(G, m1, m2, m3)`` call sites keep working.
        m1, m2, m3: Dimensionless body masses (multiples of the reference mass).
        k1: Coefficient applied to the gravitational accelerations.
            Defaults to the module-level ``K1``.
        k2: Coefficient applied to the velocities in the position
            derivatives. Defaults to the module-level ``K2``.

    Returns:
        1-D array ``[dr1, dr2, dr3, dv1, dv2, dv3]`` with the same layout as ``w``.
    """
    if k1 is None:
        k1 = K1
    if k2 is None:
        k2 = K2

    r1, r2, r3, v1, v2, v3 = w.reshape(6, -1)

    # Pairwise separations (Euclidean norms).
    r12 = np.linalg.norm(r2 - r1)
    r13 = np.linalg.norm(r3 - r1)
    r23 = np.linalg.norm(r3 - r2)

    # Newtonian attraction from the other two bodies: m_j (r_j - r_i) / |r_j - r_i|^3.
    dv1bydt = k1 * m2 * (r2 - r1) / r12**3 + k1 * m3 * (r3 - r1) / r13**3
    dv2bydt = k1 * m1 * (r1 - r2) / r12**3 + k1 * m3 * (r3 - r2) / r23**3
    dv3bydt = k1 * m1 * (r1 - r3) / r13**3 + k1 * m2 * (r2 - r3) / r23**3
    dr1bydt = k2 * v1
    dr2bydt = k2 * v2
    dr3bydt = k2 * v3

    # One NumPy concatenate in state order. The scipy.concatenate alias used
    # previously no longer exists in current SciPy.
    return np.concatenate((dr1bydt, dr2bydt, dr3bydt, dv1bydt, dv2bydt, dv3bydt))

def generate_traj(n_time: int = 8, n_take: int = 8, n_pred: int = 3) -> tuple:
    """Simulate one noisy three-body trajectory and split it into input/target parts.

    The bodies (masses 1.1, 0.9, 1.0) start at the corners of an equilateral
    triangle with side ~3.464. Each velocity component is drawn uniformly from
    ``[-0.05, 0.05)``. The system is integrated with ``odeint`` at
    ``n_time + n_pred`` evenly spaced times in ``[0, 5]``.

    Args:
        n_time: Number of steps in the observed window.
        n_take: Number of observed steps to keep. The last observed step
            (index ``n_time - 1``) is always kept. The remaining
            ``n_take - 1`` steps are drawn without replacement from the
            earlier ones.
        n_pred: Number of steps after the observed window to return as targets.

    Returns:
        ``(x, y, t)``:
        - ``x`` has shape ``(n_take, 9)``: noisy positions ``[r1, r2, r3]`` at the kept steps.
        - ``y`` has shape ``(n_pred, 9)``: noisy positions at the ``n_pred`` steps after the window.
        - ``t`` holds the sorted step indices used for ``x``.

    Raises:
        ValueError: If ``n_take`` is not in ``[1, n_time]`` or ``n_pred`` is negative.
    """
    if not 1 <= n_take <= n_time:
        raise ValueError(f"n_take must be in [1, n_time={n_time}], got {n_take}")
    if n_pred < 0:
        raise ValueError(f"n_pred must be non-negative, got {n_pred}")

    m1 = 1.1
    m2 = 0.9
    m3 = 1.0

    r1 = np.array([0, 0, 0], dtype="float64")
    r2 = np.array([1.732, 3, 0], dtype="float64")
    r3 = np.array([3.464, 0, 0], dtype="float64")

    # Small random initial velocities, uniform in [-0.05, 0.05) per component.
    v1 = (np.random.random(3) - 0.5) / 10
    v2 = (np.random.random(3) - 0.5) / 10
    v3 = (np.random.random(3) - 0.5) / 10

    # State layout expected by threeBodyEquations: [r1, r2, r3, v1, v2, v3].
    init_params = np.array([r1, r2, r3, v1, v2, v3]).flatten()
    # n_time observed steps followed by n_pred prediction steps.
    time_span = np.linspace(0, 5, n_time + n_pred)

    # Kept observed steps: n_take - 1 distinct earlier steps plus the last one, in order.
    t = np.sort(np.random.choice(np.arange(n_time - 1), size=n_take - 1, replace=False))
    t = np.append(t, n_time - 1)

    # Keep only the 9 position components (r1, r2, r3) of each integrated state.
    positions = odeint(threeBodyEquations, init_params, time_span, args=(G, m1, m2, m3))[:, :9]

    # Slice from n_time rather than [-n_pred:], which would return every row
    # when n_pred == 0.
    return add_noise(positions[t]), add_noise(positions[n_time:]), t

def create_dataset(num: int, n_time: int, n_take: int, n_pred: int, save_name: str) -> None:
    """Generate ``num`` trajectories with ``generate_traj`` and save them as ``.npy`` files.

    Three files are written: ``<base>_x.npy``, ``<base>_y.npy`` and
    ``<base>_t.npy``. Here ``<base>`` is ``save_name`` with one trailing
    ``.npy`` removed, if present. Each file stacks the corresponding
    ``generate_traj`` output across all trajectories.

    Args:
        num: Number of trajectories. Passed through ``int()``, so floats such
            as ``5e4`` are accepted.
        n_time, n_take, n_pred: Forwarded unchanged to ``generate_traj``.
        save_name: Base path for the output files.
    """
    # Drop a single ".npy" so the suffixes below don't produce "name.npy_x.npy".
    if save_name.endswith('.npy'):
        save_name = save_name[:-4]

    trials_x = []
    trials_y = []
    trials_t = []

    for _ in tqdm(range(int(num))):
        x, y, t = generate_traj(n_time, n_take, n_pred)
        trials_x.append(x)
        trials_y.append(y)
        trials_t.append(t)

    # Always save. Names such as "a.npy.npy" still end in ".npy" after
    # stripping, and must not be skipped.
    np.save(save_name + '_x.npy', np.array(trials_x))
    np.save(save_name + '_y.npy', np.array(trials_y))
    np.save(save_name + '_t.npy', np.array(trials_t))

if __name__ == '__main__':
    # Both splits share one sampling grid: 8 observed steps, 6 kept, 3 predicted.
    for split_name, n_samples in (('train', 5e4), ('test', 5e3)):
        create_dataset(n_samples, 8, 6, 3, split_name)